
# Load utterance-level transcripts with their precomputed embeddings, dropping
# utterances whose who_speaking is "Lead" or "Other". speech_english starts as a
# copy of quote and is the text key used by later joins.
transcripts_with_embeddings <- read_csv("data/cleaned/transcripts_with_embeddings.csv") %>%
  mutate(speech_english = quote) %>%
  filter(!(who_speaking %in% c("Lead", "Other")))

set.seed(123)

glimpse(transcripts_with_embeddings)

# Row positions (ascending) of the first occurrence of each distinct quote;
# used below to remove duplicate utterances before t-SNE.
transcript_non_dup_ids <- which(!duplicated(transcripts_with_embeddings$quote))

# Coerce a (typically character) matrix to a numeric matrix with the same
# number of rows and columns. Dimnames are not carried over; entries that
# cannot be parsed become NA (as.numeric() warns in that case).
as.numeric.matrix <- function(x) {
  dims <- dim(x)
  matrix(as.numeric(x), nrow = dims[1], ncol = dims[2])
}

# Split each ";"-separated embedding string into 1536 pieces and convert them to
# numbers: one row per utterance, one column per embedding dimension.
embeddings_char <- str_split_fixed(
  transcripts_with_embeddings$embedding,
  pattern = ";",
  n = 1536
)
embeddings <- as.numeric.matrix(embeddings_char)

# The same matrix restricted to the first occurrence of each quote
embeddings_no_dups <- embeddings[transcript_non_dup_ids, ]

# Use k-means clustering to categorise utterances ----------------------------------------------------------------------------------------------------------------

# Step 1: Project the de-duplicated embeddings to 2D using t-SNE.
# (Duplicates were removed above — presumably because Rtsne() rejects
# duplicate rows by default.)
tsne_result <- Rtsne(
  embeddings_no_dups,
  dims = 2,
  perplexity = 30,
  max_iter = 500,
  verbose = TRUE
)

# 2D coordinates keyed by row position in transcripts_with_embeddings.
# NOTE: despite the "pca" column names these are t-SNE coordinates.
unique_pca_results <- tibble(
  row_number = transcript_non_dup_ids,
  pca_1 = tsne_result$Y[, 1],
  pca_2 = tsne_result$Y[, 2]
)

# Merge back in to main data: attach the coordinates to the text of the
# de-duplicated rows, then spread them to every row sharing that speech_english.
quotes_for_pca <- transcripts_with_embeddings %>%
  ungroup() %>%
  mutate(row_number = row_number()) %>%
  tidylog::inner_join(unique_pca_results, by = "row_number") %>%
  select(speech_english, pca_1, pca_2)

transcripts_with_pca <- tidylog::left_join(
  transcripts_with_embeddings,
  quotes_for_pca,
  by = "speech_english"
)


# Step 2: Perform k-means clustering on the full embedding matrix (duplicates
# included), so every transcript row receives a cluster label.
rerun_kmeans <- TRUE  # TRUE refits k-means; FALSE loads the previously saved fit

if (rerun_kmeans) {

  n_clusters <- 20  # You can adjust this number
  set.seed(123)
  # nstart = 100 means it tries 100 random starts and keeps the best result
  kmeans_result <- kmeans(embeddings, centers = n_clusters, nstart = 100)

  save(kmeans_result, file = "data/cleaned/transcripts_with_embeddings_kmeans_result.RData")
} else {
  load("data/cleaned/transcripts_with_embeddings_kmeans_result.RData")
}

# Take the cluster count from the fit itself: downstream code uses n_clusters,
# and it was previously undefined whenever the fit was loaded from disk.
n_clusters <- nrow(kmeans_result$centers)

# One cluster label per row; relies on transcripts_with_pca having the same
# rows, in the same order, as the embeddings matrix.
transcripts_with_kmeans <- transcripts_with_pca %>%
  mutate(cluster = kmeans_result$cluster)


# Step 3: Visualize the results

# t-SNE coordinates coloured by k-means cluster
ggplot(transcripts_with_kmeans, aes(x = pca_1, y = pca_2, color = factor(cluster))) +
  geom_point(alpha = 0.4) +
  theme_minimal() +
  labs(title = "Embeddings Visualization with Cluster Labels",
       x = "t-SNE feature 1",
       y = "t-SNE feature 2")

# Same coordinates coloured by pair_includes_trans (the title previously said
# "Cluster Labels", which described the plot above, not this one)
ggplot(transcripts_with_kmeans, aes(x = pca_1, y = pca_2, color = factor(pair_includes_trans))) +
  geom_point(alpha = 0.2) +
  theme_minimal() +
  labs(title = "Embeddings Visualization by Whether the Pair Includes Trans",
       color = "Pair includes trans",
       x = "t-SNE feature 1",
       y = "t-SNE feature 2")

# Get representative quotes from each cluster ----------------------------------------------------------------------------------------------------------------


# Split a matrix into a list with one element per row; each element is that
# row as a plain vector. Row names (if any) become the list's names.
matrix_to_list <- function(mat) {
  keep_row <- function(row) row
  apply(X = mat, MARGIN = 1L, FUN = keep_row, simplify = FALSE)
}

# 1. Mean embedding for each cluster: the k-means centers, one row per cluster
#    id, with each center stored as a numeric vector in a list column
cluster_means <- tibble(
  cluster = seq_len(nrow(kmeans_result$centers)),
  cluster_mean_embedding = kmeans_result$centers %>% matrix_to_list()
)

# Parse each row's own ";"-separated embedding string into a numeric vector in a
# list column, then attach its cluster's mean embedding.
# The embedding is read from the piped data itself, rather than from
# transcripts_with_embeddings, so it cannot silently misalign if the row order
# or count of transcripts_with_kmeans ever differs.
transcripts_with_kmeans_tbl <- transcripts_with_kmeans %>%
  mutate(
    embedding = embedding %>% str_split(pattern = ";") %>% map(as.numeric)
  ) %>%
  left_join(cluster_means, by = "cluster")

# Calculate the cosine similarity between each quote and the mean embedding for its cluster.
# Row-wise (paired) cosine similarity: element i is the cosine of the angle
# between row i of embeddings1 and row i of embeddings2.
# Each argument may be a numeric matrix (one vector per row) or a list of
# equal-length numeric vectors, which is stacked row-wise into a matrix.
# A zero-length vector yields NaN (0 / 0).
cosine_similarity_paired <- function(embeddings1, embeddings2) {
  to_matrix <- function(vectors) {
    if (is.list(vectors)) do.call(rbind, vectors) else vectors
  }
  left <- to_matrix(embeddings1)
  right <- to_matrix(embeddings2)

  # Dot product divided by the product of the two Euclidean norms
  rowSums(left * right) / (sqrt(rowSums(left^2)) * sqrt(rowSums(right^2)))
}


# Similarity of every utterance to the mean embedding of its own cluster
transcripts_kmeans_cosine <- transcripts_with_kmeans_tbl %>%
  mutate(
    similarity_to_cluster_mean = cosine_similarity_paired(embedding, cluster_mean_embedding)
  )

# Up to the three utterances closest to each cluster's mean (more if tied).
# Trailing punctuation is stripped first, so utterances that differ only in
# ending punctuation count as duplicates, and only the first of each is kept.
transcripts_max_similarity <- transcripts_kmeans_cosine %>%
  mutate(speech_english = str_remove(speech_english, "[[:punct:]]+$")) %>%
  distinct(speech_english, .keep_all = TRUE) %>%
  group_by(cluster) %>%
  arrange(desc(similarity_to_cluster_mean)) %>%
  dplyr::slice_max(similarity_to_cluster_mean, n = 3) %>%
  ungroup() %>%
  relocate(cluster, similarity_to_cluster_mean, quote)


# Which clusters are predictive of choosing trans ----------------------------------------------------------------------------------------------------------------

# For each group and pair type (pair_includes_trans, recoded to 0/1), the share
# of that group's utterances falling in each cluster
p_cluster <- transcripts_with_kmeans_tbl %>%
  select(group_id, cluster, speech_english, pair_includes_trans) %>%
  group_by(group_id, pair_includes_trans) %>%
  mutate(
    n_utterances = n()
  ) %>%
  mutate(pair_includes_trans = as.integer(pair_includes_trans)) %>%
  group_by(group_id, pair_includes_trans, cluster, n_utterances) %>%
  summarise(
    n_cluster = n()
  ) %>%
  mutate(p_cluster = n_cluster / n_utterances) %>%
  ungroup %>%
  print

# Wide format: one cluster_<k> column per cluster; clusters a group never used
# are filled with 0
p_cluster_wide <- p_cluster %>%
  select(-n_cluster) %>%
  pivot_wider(names_from = cluster, names_prefix = "cluster_", values_from = p_cluster, values_fill = 0) %>%
  ungroup()

# Merge with r2: attach each group's cluster shares from its trans-pair
# discussions, then standardise the shares with z_calc_std().
# pair_includes_trans is dropped from the right-hand side because
# r2_choices_num has its own pair_includes_trans column; keeping both made
# left_join() rename them to pair_includes_trans.x / pair_includes_trans.y.
r2_with_transcript_clusters <- r2_choices_num %>%
  left_join(
    p_cluster_wide %>%
      filter(pair_includes_trans == 1) %>%
      select(-pair_includes_trans),
    by = "group_id"
  ) %>%
  mutate(across(matches("^cluster_\\d+"), z_calc_std))

# How common is each cluster? ----------------------------------------------------------------------------------------------------------------

# Bar chart of cluster shares by pair type. Cluster/group/pair-type combinations
# that never occurred are added with p_cluster = 0 before plotting.
p_cluster %>%
  complete(
    cluster = 1:n_clusters,
    group_id,
    pair_includes_trans,
    fill = list(p_cluster = 0)
  ) %>%
  mutate(
    pair_includes_trans = factor(pair_includes_trans),
    cluster = factor(cluster)
  ) %>%
  bar_chart(x = cluster, y = p_cluster, fill = pair_includes_trans)


# Output table ----------------------------------------------------------------------------------------------------------------

# Representative quotes for each cluster, one line per cluster: each quote is
# wrapped in single quotes and the quotes are joined with "; "
cluster_quotes <- transcripts_max_similarity %>%
  group_by(cluster) %>%
  summarise(quotes = paste0("'", quote, "'", collapse = "; "))

# Frequency by trans / non-trans: each cluster's share of all utterances pooled
# over groups, one column per pair_includes_trans value (p_cluster0, p_cluster1)
p_cluster_by_trans <- p_cluster %>%
  group_by(pair_includes_trans, cluster) %>%
  summarise(n_cluster = sum_na(n_cluster)) %>%
  group_by(pair_includes_trans) %>%
  mutate(p_cluster = n_cluster / sum_na(n_cluster)) %>%
  select(-n_cluster) %>%
  pivot_wider(
    names_from = pair_includes_trans,
    values_from = p_cluster,
    names_prefix = "p_cluster"
  )

# Is it predictive? One regression per cluster: r2_choose_trans on that
# cluster's share plus the design controls and fixed effects; only the
# cluster-share coefficient is kept.
coeff_clusters <- paste0("cluster_", 1:n_clusters) %>%
  map(
    ~ feols_custom(
      as.formula(
        paste0("r2_choose_trans ~ r2_reliability_diff * r2_reliability_shown + item_diff + ", .x)
      ),
      data = r2_with_transcript_clusters,
      fixef = c("stratum_id", "video_type", "delivery_incentive_exp", "comparator_order_in_pair", "phase")
    )
  ) %>%
  map(tidy_90) %>%
  bind_rows() %>%
  filter(term %in% paste0("cluster_", 1:n_clusters)) %>%
  mutate(cluster = as.integer(str_remove(term, "cluster_")))

# One row per k-means cluster: quotes, shares in non-trans / trans pairs, and
# the cluster coefficient with its p-value
cluster_table <- tibble(cluster = 1:n_clusters) %>%
  left_join(cluster_quotes, by = "cluster") %>%
  left_join(p_cluster_by_trans, by = "cluster") %>%
  left_join(coeff_clusters, by = "cluster") %>%
  select(
    cluster_id = cluster,
    quotes,
    p_non_trans = p_cluster0,
    p_trans = p_cluster1,
    coeff = estimate,
    p_value = p.value
  )

# Helper function to convert labels containing "\n" into LaTeX \makecell{...}
# cells, turning each newline into a "\\" line break.
# Vectorised: accepts a whole character vector. Elements without a newline,
# including NA, are returned unchanged (the old scalar `if` errored on NA).
convert_to_makecell <- function(x) {
  # grepl() gives FALSE (not NA) for NA input, so NA elements are left alone
  has_newline <- which(grepl("\n", x))
  if (length(has_newline) > 0) {
    # "\\\\\\\\" is a regex replacement that emits two literal backslashes
    x[has_newline] <- paste0("\\makecell{", gsub("\n", "\\\\\\\\", x[has_newline]), "}")
  }
  x
}

# Write the cluster table as a LaTeX fragment. Rows are ranked by p-value and
# renumbered 1..n in that order, so table IDs differ from the k-means ids.
local({
  display_table <- cluster_table %>%
    # Shares as percentages, coefficient in percentage points
    mutate(across(c(p_non_trans, p_trans, coeff), ~ .x * 100)) %>%
    arrange(p_value) %>%
    mutate(
      p_value = q_val(p_value),
      cluster_id = row_number(),
      # Significance stars from the q-values
      p_stars = case_when(
        p_value < 0.01 ~ "***",
        p_value < 0.05 ~ "**",
        p_value < 0.1 ~ "*",
        TRUE ~ ""
      ),
      p_value_stars = paste0(format(round(p_value, 2), nsmall = 2), p_stars)
    ) %>%
    select(
      ID = cluster_id,
      `Representative quotes` = quotes,
      `No trans` = p_non_trans,
      `Includes trans` = p_trans,
      `$\\beta$ (p.p.)` = coeff,
      `q-value` = p_value_stars
    )

  # Column names containing "\n" become \makecell{} headers
  names(display_table) <- map_chr(names(display_table), convert_to_makecell)

  latex_table <- kable(
    display_table,
    format = "latex",
    digits = c(0, 0, 1, 1, 1, 0),
    booktabs = TRUE,
    linesep = "",
    escape = FALSE,
    align = c("c", "l", "c", "c", "c", "c")
  ) %>%
    kable_styling() %>%
    add_header_above(c(" " = 2, 
                      "\\\\% sentences in discussion" = 2, 
                      "\\\\makecell{Association between \\\\\\\\ \\\\% sentences in trans discussions and \\\\\\\\P(chose trans in outcome round)}" = 2),
                      escape = FALSE) %>%
    kableExtra::column_spec(column = 2, width = "4in") %>%
    kable_remove_table()

  writeLines(latex_table, "outputs/tables/transcript_cluster_quotes.tex")
})

# Export the q-value and coefficient for the clusters cited in the text.
# q-values are computed once over ALL clusters, sorted by p-value exactly as in
# the LaTeX table, and only then is a single cluster selected. The previous code
# filtered to one cluster first, so q_val() only ever saw a single p-value.
# NOTE(review): assumes q_val() is a multiple-testing adjustment over its whole
# input vector, as the table's "q-value" column suggests — confirm.
cluster_table_q <- cluster_table %>%
  arrange(p_value) %>%
  mutate(q_value = q_val(p_value))

# Write pval_/coeff_transcript_cluster_<label>.tex for k-means cluster
# `target_id`. The coefficient is scaled with times_100(), matching the
# percentage-point scaling of coeff in the table.
write_cluster_stats <- function(target_id, label) {
  cluster_row <- cluster_table_q %>%
    filter(cluster_id == !!target_id)

  cluster_row %>%
    pull(q_value) %>%
    write_stat(paste0("outputs/stats/pval_transcript_cluster_", label, ".tex"), digits = 3, p_value = TRUE)

  cluster_row %>%
    pull(coeff) %>%
    times_100 %>%
    write_stat(paste0("outputs/stats/coeff_transcript_cluster_", label, ".tex"), digits = 1)

  invisible(NULL)
}

# NOTE(review): k-means ids depend on the seed/data, and k-means cluster 20 is
# written under the label "1"; re-check these ids whenever k-means is re-run.
write_cluster_stats(20, "1")
write_cluster_stats(2, "2")
write_cluster_stats(12, "12")


# Which arguments are used for anti-trans? ----------------------------------------------------------------------------------------------------------------

glimpse(discuss_obs)
glimpse(transcripts_with_embeddings)

# Helper functions
# Parse ONE ";"-separated embedding string into a numeric vector (only the
# first element is used if a longer vector is passed).
string_to_vector <- function(embedding_string) {
  pieces <- strsplit(embedding_string, split = ";")
  as.numeric(pieces[[1L]])
}

# Cosine similarity between one vector and every row of a matrix; returns a
# plain numeric vector with one value per matrix row (NaN for zero-norm rows).
cosine_similarity_custom <- function(vec, mat) {
  numerators <- mat %*% vec
  denominators <- sqrt(sum(vec^2)) * sqrt(rowSums(mat^2))
  as.vector(numerators / denominators)
}

# Main analysis: compare each utterance with the average embedding of
# pro_trans_arg == TRUE utterances and of pro_trans_arg == FALSE utterances.

# Convert string embeddings to a list of numeric vectors, then stack them into
# a matrix with one row per utterance
embeddings_list <- transcripts_with_embeddings$embedding %>%
  map(string_to_vector)

embeddings_matrix <- do.call(rbind, embeddings_list)

# Rows of each group. which() skips rows where pro_trans_arg is NA (a logical
# index would instead produce all-NA rows), and drop = FALSE keeps a matrix
# even when only one row matches (colMeans() requires at least two dimensions).
pro_trans_rows <- which(transcripts_with_embeddings$pro_trans_arg == TRUE)
non_pro_trans_rows <- which(transcripts_with_embeddings$pro_trans_arg == FALSE)

pro_trans_embeddings <- embeddings_matrix[pro_trans_rows, , drop = FALSE]
non_pro_trans_embeddings <- embeddings_matrix[non_pro_trans_rows, , drop = FALSE]

mean_embedding_pro <- colMeans(pro_trans_embeddings, na.rm = TRUE)
mean_embedding_non <- colMeans(non_pro_trans_embeddings, na.rm = TRUE)

# Calculate similarities of every utterance to each group mean
similarities_pro <- cosine_similarity_custom(mean_embedding_pro, embeddings_matrix)
similarities_non <- cosine_similarity_custom(mean_embedding_non, embeddings_matrix)

# Add results to dataframe; a negative difference means closer to the
# non-pro-trans mean than to the pro-trans mean
transcripts_pro_anti <- transcripts_with_embeddings %>%
  mutate(
    similarity_pro_trans = similarities_pro,
    similarity_non_pro_trans = similarities_non,
    similarity_diff_pro_anti = similarity_pro_trans - similarity_non_pro_trans
  )

transcripts_pro_anti %>%
  arrange(similarity_diff_pro_anti) %>%
  select(quote, similarity_diff_pro_anti, reasons_joined) %>%
  print(n = 100)


# Look at the most anti-trans arguments ----------------------------------------------------------------------------------------------------------------

# Non-pro-trans utterances, lowest similarity_diff_pro_anti first
transcripts_pro_anti %>%
  filter(pro_trans_arg == FALSE) %>%
  arrange(similarity_diff_pro_anti) %>%
  select(quote, similarity_diff_pro_anti, reasons_joined) %>%
  view()

glimpse(discuss_obs)

# Attach discuss_obs by group and round (transcripts' pair_id matches
# discuss_obs' round). Columns discuss_obs shares with transcripts_pro_anti are
# dropped first, except the group_id join key.
transcripts_with_discuss_obs <- tidylog::left_join(
  transcripts_pro_anti,
  discuss_obs %>%
    select(-any_of(names(transcripts_pro_anti)), group_id) %>%
    mutate(round = as.numeric(round)),
  by = c("group_id", "pair_id" = "round")
)

# Rounds with neg_mentions where at least one utterance cites one of the
# listed reasons
transcripts_with_discuss_obs %>%
  filter(neg_mentions == TRUE) %>%
  group_by(group_id, pair_id) %>%
  filter(any(str_detect(reasons_joined, "Easy to talk with this person|Appearance of person|Other person looks indecent|Age"))) %>%
  select(group_id, quote, pro_trans_arg, similarity_diff_pro_anti, reasons_joined) %>%
  view()

# Look at every demographic and see correlation with reason ----------------------------------------------------------------------------------------------------------------

group_relations %>% glimpse

# Is there a correlation between group characteristics and what they say?
# One row per group: group means of each demographic and of
# relation_score_fact_z_group, then standardised with z_calc_std()
group_demo <- df %>%
  left_join(relations_scores_group, by = "group_id") %>%
  select(group_id, all_of(unname(demo_vars)), relation_score_fact_z_group) %>%
  group_by(group_id) %>%
  summarise(across(everything(), mean_na)) %>%
  mutate(
    across(-group_id, z_calc_std)
  )

transcripts_for_demo <- transcripts_pro_anti %>%
  left_join(group_demo, by = "group_id")


# For each (demographic, reason) pair, regress reason_pos (reason mentioned in a
# pro_trans_arg utterance) and reason_neg (mentioned in a non-pro_trans_arg
# utterance) on the group-level demographic, clustering by group_id.
# Fits or tidies that fail become NULL via safely().
all_demo_reason_effects <- tidyr::crossing(
  # every group_demo column except the leading group_id key (summarise() puts
  # the grouping column first); [-1] also stays correct if it is the only column
  demo_var = names(group_demo)[-1],
  reason_type = transcripts_by_reason %>% mutate(reasons_joined = as.character(reasons_joined)) %>% filter(n > 10) %>% pull(reasons_joined) %>% unique()
) %>%
  rowwise() %>%
  mutate(
    data = list(
      transcripts_for_demo %>%
        filter(!is.na(pro_trans_arg)) %>%
        mutate(
          # fixed(): reason labels are literal text, so characters such as "."
          # or "(" in a label are not interpreted as regex syntax
          reason_pos = str_detect(reasons_joined, fixed(reason_type)) & pro_trans_arg,
          reason_neg = str_detect(reasons_joined, fixed(reason_type)) & !pro_trans_arg
        )
    )
  ) %>%
  ungroup %>%
  mutate(
    model_pos = map2(
      demo_var, data,
      safely(~ fixest::feols(as.formula(paste0("reason_pos ~ ", .x)), data = .y, cluster = "group_id"))
    ),
    model_neg = map2(
      demo_var, data,
      safely(~ fixest::feols(as.formula(paste0("reason_neg ~ ", .x)), data = .y, cluster = "group_id"))
    ),
    model_pos = map(model_pos, ~ .x$result),
    model_neg = map(model_neg, ~ .x$result)
  ) %>%
  mutate(
    coeffs_pos = map(model_pos, safely(tidy_90)) %>% map("result"),
    coeffs_neg = map(model_neg, safely(tidy_90)) %>% map("result"),
    coeffs = map2(coeffs_pos, coeffs_neg, ~ bind_rows("pos" = .x, "neg" = .y, .id = "pos_neg"))
  ) %>%
  select(-c(data:coeffs_neg)) %>%
  unnest(coeffs) %>%
  filter(term != "(Intercept)")


all_demo_reason_effects

# PLOT the effects significant at p < 0.05, excluding "no reason" statements
demo_reason_plot <- all_demo_reason_effects %>%
  filter(!str_detect(term, "group_id")) %>%
  filter(p.value < 0.05) %>%
  filter(!str_detect(reason_type, "Stating preference with no reason")) %>%
  mutate(
    y = fct_cross(term, reason_type, pos_neg) %>% fct_reorder(estimate)
  ) %>%
  ggplot(aes(x = estimate, y = y, color = pos_neg)) +
  geom_point() +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0.2) +
  theme_minimal() +
  labs(title = "Effect of Demographics") +
  geom_vline(xintercept = 0, linetype = "dashed")

demo_reason_plot

# Save the plot object explicitly instead of relying on last_plot()
ggsave("outputs/figs/demographics_and_reasons.pdf", plot = demo_reason_plot, width = 10, height = 15)


# Amalgamate to the group level so can be used for heterogeneity with R2 ----------------------------------------------------------------------------------------------------------------

transcripts_pro_anti %>% count_prop(pro_comparator_arg, pair_includes_trans)

# Per group, reason, pro_comparator_arg and pair type: the number of rounds
# (pair_id) in which that reason was mentioned at least once
transcripts_reasons <- transcripts_pro_anti %>%
  separate_rows(reasons_joined, sep = "; ") %>%
  mutate(
    reasons_joined = case_when(
      str_detect(reasons_joined, "Person seems reliable") ~ "Person seems reliable etc.",
      TRUE ~ reasons_joined
    )
  ) %>%
  # Get the sum of utterances for each reason x pro/anti, in each group round
  filter(!is.na(pro_comparator_arg)) %>%
  group_by(group_id, pair_id, reasons_joined, pro_comparator_arg, pair_includes_trans) %>%
  count() %>%
  ungroup() %>%
  # NOTE(review): complete() crosses pair_includes_trans independently of
  # pair_id, adding zero rows for round/pair-type combinations that never
  # occurred; tidyr::nesting(pair_id, pair_includes_trans) would restrict to
  # observed combinations — confirm which is intended.
  tidyr::complete(group_id, pair_id, reasons_joined, pro_comparator_arg, pair_includes_trans, fill = list(n = 0)) %>%
  arrange(group_id, pair_id, reasons_joined, pro_comparator_arg, pair_includes_trans) %>%
  # Calculate number of group-rounds where each reason was mentioned
  group_by(group_id, reasons_joined, pro_comparator_arg, pair_includes_trans) %>%
  summarise(reason_mentioned_yn = sum_na(n > 0))

# One regression of r2_choose_trans on reason_mentioned_yn per
# (reason, pro_comparator_arg, pair type) cell; failed fits become NULL
r2_with_reasons_transcript <- r2_choices_num %>%
  select(-pair_includes_trans) %>%
  tidylog::left_join(transcripts_reasons %>% rename(pair_includes_trans_in_discussion = pair_includes_trans), by = "group_id") %>%
  group_by(reasons_joined, pro_comparator_arg, pair_includes_trans_in_discussion) %>%
  nest() %>%
  mutate(
    n_instances = map_int(data, ~ sum_na(.x$reason_mentioned_yn)),
    model = map(data,
                safely(
                  ~ fixest::feols(as.formula("r2_choose_trans ~ reason_mentioned_yn"), data = .x, fixef = "stratum_id", cluster = "group_id")
                )) %>%
      map("result")
  )


# r2_with_reasons_transcript

r2_with_reasons_transcript %>% ungroup() %>% count_prop(n_instances)

# Coefficients for cells with more than 300 mentions. The unused
# `height = 0.2` argument was removed: geom_pointrange() has no such parameter
# and only warned about it.
r2_with_reasons_transcript %>%
  filter(n_instances > 300) %>%
  mutate(
    coeffs = map(model, safely(tidy_90)) %>% map("result")
  ) %>%
  unnest(coeffs) %>%
  filter(term != "(Intercept)", reasons_joined != "w") %>%
  arrange(desc(estimate)) %>%
  ggplot(aes(x = estimate, y = reasons_joined, color = pro_comparator_arg)) +
  facet_wrap(~ pair_includes_trans_in_discussion) +
  geom_pointrange(aes(xmin = conf.low, xmax = conf.high), position = position_dodge(0.3)) +
  geom_vline(xintercept = 0, linetype = "dashed")


r2_with_reasons_transcript

# Heterogeneity by what proportion of discussion was just pure statement of position (A vs B) ----------------------------------------------------------------------------------------------------------------

# Per group: share of reason-coded utterances that only state a preference or
# agree with the previous speaker, for trans pairs, non-trans pairs, and
# averaged over all pair types
transcripts_no_reasons <- transcripts_pro_anti %>%
  group_by(group_id, pair_includes_trans) %>%
  filter(!is.na(reasons_joined)) %>%
  summarise(
    n_statements = n(),
    n_statements_no_reason = sum(str_detect(reasons_joined, "Stating preference with no reason|Agreeing with previous person")),
    p_statements_no_reason = n_statements_no_reason / n_statements
  ) %>%
  group_by(group_id) %>%
  summarise(
    p_statements_no_reason_trans = mean_na(p_statements_no_reason[pair_includes_trans == 1]),
    p_statements_no_reason_non_trans = mean_na(p_statements_no_reason[pair_includes_trans == 0]),
    p_statements_no_reason = mean_na(p_statements_no_reason)
  )


# Now, merge with the main dataset
r2_with_statements <- tidylog::left_join(r2_choices_num, transcripts_no_reasons, by = "group_id")

# Plot heterogeneity by proportion of statements WITH a reason (1 - share without)
r2_with_statements %>%
  filter(p_statements_no_reason_trans < 0.8) %>%
  ggplot(aes(x = 1 - p_statements_no_reason_trans, y = r2_choose_trans)) +
  geom_smooth(method = "lm")


# REPRESENTATIVE QUOTES --------------------------------------------------------------------------------------------------------------------------------

transcripts_with_embeddings %>% names()

# Parse the ";"-separated mean_embedding_0 / mean_embedding_1 columns into
# numeric matrices (one row per utterance, 1536 columns)
mean_embedding_0 <- transcripts_with_embeddings$mean_embedding_0 %>% str_split_fixed(pattern = ";", n = 1536) %>%
  as.numeric.matrix()

mean_embedding_1 <- transcripts_with_embeddings$mean_embedding_1 %>% str_split_fixed(pattern = ";", n = 1536) %>%
  as.numeric.matrix()

df %>% count_prop(is_listener)

# TOP QUOTES
# Write the quote with the highest similarity_diff to a .tex snippet.
# Only a TRAILING full stop is removed; the previous str_remove("\\.") deleted
# the first period anywhere in the quote. The appended "%" is presumably a
# LaTeX comment so the file's final newline adds no space when \input.
transcripts_with_embeddings %>%
  arrange(desc(similarity_diff)) %>%
  pull(quote) %>%
  .[[1]] %>%
  str_remove("\\.$") %>%
  paste0("%") %>%
  writeLines("outputs/stats/top_quote.tex")

# Same, restricted to quotes that do not mention "transgender" (case-insensitive)
transcripts_with_embeddings %>%
  filter(!str_detect(quote, regex("transgender", ignore_case = TRUE))) %>%
  arrange(desc(similarity_diff)) %>%
  pull(quote) %>%
  .[[1]] %>%
  str_remove("\\.$") %>%
  paste0("%") %>%
  writeLines("outputs/stats/top_quote_no_trans.tex")